#! /usr/bin/env python
"""
@author: Michaël Baudin (EDF), Antoine Dumas (Phiméca)


Checks the distribution of the sensitivity indices estimators.

We compare the asymptotic distribution with an empirical sample of Sobol'
indices estimates.

We want to check that the distribution computed by the library is OK.
When this distribution is estimated with the asymptotic estimator,
it is Gaussian.
Otherwise, it is computed using kernel smoothing based on a bootstrap re-sample.
A first significant point is the mean of the distribution.
The second significant point is the variance of the distribution, which
is computed using the Delta-method when the asymptotic distribution is
used.

The reference distribution is computed based on a sample of Sobol' indices,
which is generated by repetition.
Then we compare the "Sample" distribution (from repetition) and the
"Computed" distribution from the library.
These two distributions should be close.

The asymptotic distribution can be plotted by two different methods.

- We can perform a single estimation of Sobol' indices (using a given sample size)
and get the estimated distribution.
The problem with this distribution is that it can be shifted with respect to the
true value of the Sobol' index.
This can be because of the limited sample size or because of a bias in the
estimator.
This is done if mean_distribution==False (the distribution of the last
repetition is used).
In this case, the figures are in the directory:
"graphe_validation/".

- We can repeat the estimations several times and create a new Gaussian distribution,
whose mean is computed as the mean of the means of the Gaussians and
whose standard deviation is equal to the mean of the standard deviations.
This allows one to slightly correct the shifting generated by the limited size of the
sample, but does not correct the bias of the estimator, if any.
This is what is done if mean_distribution==True.
In this case the figures are in the "graphe_validation_mean_distribution/" directory.

The sample distribution of the Sobol' indices is estimated using a histogram,
based on a Monte-Carlo sample.

The check is based on the comparison between the distribution computed by the
library using a single sample and the distribution estimated by repetition.
"""


import matplotlib as mpl

import openturns as ot
import numpy as np
from openturns.viewer import View
from gsobollib import gsobolSAExact, gsobolDistribution, gsobol
from ishigamilib import (
    ishigamiSAExact,
    ishigamiDistribution,
    ishigamiAB,
    ishigamiGSymbolic,
)
from numpy import zeros, array
import pylab as pl
import pathlib
from sobol_variance_estimators import (
    SaltelliSensitivityAlgorithm,
    MartinezSensitivityAlgorithm,
    JansenSensitivityAlgorithm,
    MauntzKucherenkoSensitivityAlgorithm,
)

# Select matplotlib's "agg" backend: figures are rendered off-screen and are
# only useful when written to files (see ``savefig`` below).
mpl.use("agg")


class SensitivityConfidenceTest:
    """
    Validate the variance of Sobol' indices estimators (Janon method).

    The asymptotic distribution of the estimators is compared with the
    empirical distribution obtained by repeating the computation of the
    Sobol' indices estimators. The repetitions are computed in the
    constructor; the comparison is printed and plotted if ``plot_figure``
    is True.

    Parameters
    ----------
    model : ot.Function
        The test model (e.g. G-Sobol or Ishigami).
    distribution : ot.Distribution
        The distribution of the input variables of the model.
    sobol_estimator : class
        The sensitivity algorithm class (e.g. SaltelliSensitivityAlgorithm),
        called as ``sobol_estimator(inputDesign, outputDesign, sampleSize)``.
    FOexact, TOexact : sequence of float, optional
        Exact first and total order indices, if known. They are drawn as
        vertical lines on the histograms.
    sampleSize : int
        Size of the base experimental design used to estimate S and ST.
    seed : int
        Seed of the OpenTURNS random generator, set before the repetitions.
    nrepetitions : int
        Number of repetitions of the estimation (must be at least 1).
    alpha : float
        Confidence level of the intervals.
    savefig : bool
        If True, the figures are saved as PNG files.
    plot_figure : bool
        If True, print the interval comparison and plot the histograms.
    """

    def __init__(
        self,
        model,
        distribution,
        sobol_estimator,
        FOexact=None,
        TOexact=None,
        sampleSize=500,
        seed=154681,
        nrepetitions=100,
        alpha=0.95,
        savefig=False,
        plot_figure=False,
    ):
        # Test model: g-sobol or ishigami, for example
        self.model = model
        # Distribution of the input variables of the test model
        self.distribution = distribution
        self.dim = distribution.getDimension()
        # Sensitivity algorithm class used to estimate the indices
        # (e.g. SaltelliSensitivityAlgorithm)
        self.sobol_estimator = sobol_estimator
        # Size of the base experimental design used to estimate S and ST
        self.sampleSize = sampleSize
        # Number of repetitions of the experiment
        self.nrepetitions = nrepetitions
        # Bootstrap confidence level, also used for the empirical and
        # asymptotic quantiles
        self.alpha = alpha
        # Whether the figures are saved
        self.savefig = savefig
        # First order exact indices, if known
        self.FOexact = FOexact
        # Total order exact indices, if known
        self.TOexact = TOexact
        # Seed of the random generator
        self.seed = seed

        (
            sampleFirst,
            sampleTotal,
            foInterval,
            toInterval,
            distFirstCol,
            distTotalCol,
        ) = self.compute_sample_indices()
        if plot_figure:
            self.compare_last_repetition(
                sampleFirst,
                sampleTotal,
                foInterval,
                toInterval,
                distFirstCol,
                distTotalCol,
            )
            self.plot_indices_histogram(
                sampleFirst, sampleTotal, distFirstCol, distTotalCol
            )
            self.plot_indices_histogram(
                sampleFirst, sampleTotal, distFirstCol, distTotalCol, True
            )

    @staticmethod
    def _average_over_repetitions(distCol, j, statistic):
        """
        Return the mean, over the repetitions, of a statistic of marginal j.

        Parameters
        ----------
        distCol : list
            One distribution per repetition; ``distCol[i][j]`` is its
            marginal j.
        j : int
            Index of the marginal.
        statistic : callable
            Maps a marginal distribution to a float.
        """
        return np.mean([statistic(dist[j]) for dist in distCol])

    def compute_sample_indices(self):
        """
        Estimate the Sobol' indices ``nrepetitions`` times.

        The random generator is seeded with ``self.seed`` first, so the
        result is reproducible. Also sets ``self.sensitivity_algorithm`` (the
        algorithm of the last repetition), ``self.std_first_empirical`` and
        ``self.std_total_empirical``.

        Returns
        -------
        sampleFirst, sampleTotal : numpy.ndarray, shape (nrepetitions, dim)
            Aggregated first and total order indices of each repetition.
        foInterval, toInterval : interval objects
            Confidence intervals of the first and total order indices, as
            returned by the algorithm of the last repetition.
        distFirstCol, distTotalCol : list
            Asymptotic distributions of the first and total order indices,
            one per repetition.

        Raises
        ------
        ValueError
            If ``nrepetitions`` is lower than 1: the intervals of the last
            repetition would not exist.
        """
        if self.nrepetitions < 1:
            raise ValueError(
                "nrepetitions must be at least 1, got %s" % self.nrepetitions
            )
        size = int(self.sampleSize)

        # Estimates of the first order and total order indices
        sampleFirst = zeros((self.nrepetitions, self.dim))
        sampleTotal = zeros((self.nrepetitions, self.dim))

        # Asymptotic distributions, one per repetition
        distFirstCol = []
        distTotalCol = []

        ot.RandomGenerator.SetSeed(self.seed)
        for i in range(self.nrepetitions):
            sobolexperiment = ot.SobolIndicesExperiment(self.distribution, size, False)
            inputDesign = sobolexperiment.generate()
            outputDesign = self.model(inputDesign)
            self.sensitivity_algorithm = self.sobol_estimator(
                inputDesign, outputDesign, size
            )
            self.sensitivity_algorithm.setConfidenceLevel(self.alpha)
            fo = self.sensitivity_algorithm.getAggregatedFirstOrderIndices()
            to = self.sensitivity_algorithm.getAggregatedTotalOrderIndices()
            distFirstCol.append(
                self.sensitivity_algorithm.getFirstOrderAsymptoticDistribution()
            )
            distTotalCol.append(
                self.sensitivity_algorithm.getTotalOrderAsymptoticDistribution()
            )
            for j in range(self.dim):
                sampleFirst[i, j] = fo[j]
                sampleTotal[i, j] = to[j]

        # Bootstrap confidence intervals of the last repetition
        foInterval = self.sensitivity_algorithm.getFirstOrderIndicesInterval()
        toInterval = self.sensitivity_algorithm.getTotalOrderIndicesInterval()

        # Empirical standard deviation over the repetitions
        self.std_first_empirical = ot.Sample(sampleFirst).computeStandardDeviation()
        self.std_total_empirical = ot.Sample(sampleTotal).computeStandardDeviation()
        return (
            sampleFirst,
            sampleTotal,
            foInterval,
            toInterval,
            distFirstCol,
            distTotalCol,
        )

    def compare_last_repetition(
        self,
        sampleFirst,
        sampleTotal,
        foInterval,
        toInterval,
        distFirstCol,
        distTotalCol,
    ):
        """
        Print, for each input, a comparison of the spread of the indices.

        For the first and total order indices, print the empirical standard
        deviation over the repetitions against the mean of the asymptotic
        standard deviations. Then print three intervals at level ``alpha``:
        - the bootstrap interval of the last repetition,
        - the empirical quantiles of the repetitions,
        - the mean of the asymptotic quantiles.
        """
        # Probability levels of the bounds of the two-sided intervals
        pLow = (1 - self.alpha) / 2
        pHigh = 1 - (1 - self.alpha) / 2

        foIntervalMin = foInterval.getLowerBound()
        foIntervalMax = foInterval.getUpperBound()
        toIntervalMin = toInterval.getLowerBound()
        toIntervalMax = toInterval.getUpperBound()

        def lowQuantile(dist):
            return dist.computeQuantile(pLow)[0]

        def highQuantile(dist):
            return dist.computeQuantile(pHigh)[0]

        def stdDev(dist):
            return dist.getStandardDeviation()[0]

        average = self._average_over_repetitions
        for j in range(self.dim):
            # Empirical quantiles over the repetitions
            sampleFirstPerDim = ot.Sample([[v] for v in sampleFirst[:, j]])
            foMinj = sampleFirstPerDim.computeQuantile(pLow)[0]
            foMaxj = sampleFirstPerDim.computeQuantile(pHigh)[0]
            sampleTotalPerDim = ot.Sample([[v] for v in sampleTotal[:, j]])
            toMinj = sampleTotalPerDim.computeQuantile(pLow)[0]
            toMaxj = sampleTotalPerDim.computeQuantile(pHigh)[0]
            # Asymptotic quantiles and standard deviations, averaged
            foAsympMinj = average(distFirstCol, j, lowQuantile)
            foAsympMaxj = average(distFirstCol, j, highQuantile)
            toAsympMinj = average(distTotalCol, j, lowQuantile)
            toAsympMaxj = average(distTotalCol, j, highQuantile)
            meanAsympFO = average(distFirstCol, j, stdDev)
            meanAsympTO = average(distTotalCol, j, stdDev)

            print("X%d" % (j))
            print(
                "   First standard deviation, Sample=%.5f, Asymptotic=%.5f"
                % (sampleFirstPerDim.computeStandardDeviation()[0], meanAsympFO)
            )
            print(
                "   Total standard deviation, Sample=%.5f, Asymptotic=%.5f"
                % (sampleTotalPerDim.computeStandardDeviation()[0], meanAsympTO)
            )
            print(
                "   First, Bootstrap=[%.4f,%.4f], Sample=[%.4f,%.4f], "
                "Asymptotic=[%.4f,%.4f]"
                % (
                    foIntervalMin[j],
                    foIntervalMax[j],
                    foMinj,
                    foMaxj,
                    foAsympMinj,
                    foAsympMaxj,
                )
            )
            print(
                "   Total, Bootstrap=[%.4f,%.4f], Sample=[%.4f,%.4f], "
                "Asymptotic=[%.4f,%.4f]"
                % (
                    toIntervalMin[j],
                    toIntervalMax[j],
                    toMinj,
                    toMaxj,
                    toAsympMinj,
                    toAsympMaxj,
                )
            )
            print("")

    def plot_indices_histogram(
        self,
        sampleFirst,
        sampleTotal,
        distFirstCol,
        distTotalCol,
        mean_distribution=False,
    ):
        """
        Plot, for each input, the histogram of the indices and the Gaussian
        asymptotic distribution.

        The first row shows the first order indices, the second row the
        total order indices.

        Parameters
        ----------
        mean_distribution : bool
            If False, use the asymptotic distribution of the last repetition.
            If True, use a Gaussian whose mean (resp. standard deviation) is
            the mean of the asymptotic means (resp. standard deviations) over
            the repetitions.

        Returns
        -------
        fig : matplotlib Figure
            The figure. It is saved if ``self.savefig`` is True, then closed
            in pyplot so that repeated calls do not accumulate open figures.
        """
        # squeeze=False keeps a 2D array of axes even when dim == 1
        fig, ax = pl.subplots(
            2, self.dim, figsize=(4 * self.dim, 8), squeeze=False
        )
        fig.suptitle(
            "%s - N=%d - Repetitions = %d"
            % (self.sobol_estimator.__name__, self.sampleSize, self.nrepetitions)
        )

        # One entry per row: (sample, asymptotic distributions, exact
        # indices, x label format, legend prefix of the asymptotic curve)
        rows = [
            (sampleFirst, distFirstCol, self.FOexact, "S%d", "asymptotique - "),
            (sampleTotal, distTotalCol, self.TOexact, "ST%d", "asymptotique- "),
        ]

        def mean(dist):
            return dist.getMean()[0]

        def stdDev(dist):
            return dist.getStandardDeviation()[0]

        # For each index, compare the empirical distribution and the
        # asymptotic one
        for j in range(self.dim):
            for row, (sample, distCol, exact, xlabelFormat, asympPrefix) in enumerate(
                rows
            ):
                axis = ax[row, j]
                sampleJ = sample[:, j]
                axis.hist(
                    sampleJ,
                    histtype="step",
                    density=True,
                    label="empirique - std=%.2e" % np.std(sampleJ),
                )
                if mean_distribution:
                    # Mean and standard deviation = mean of the means and
                    # standard deviations of the asymptotic distributions
                    mu = self._average_over_repetitions(distCol, j, mean)
                    sigma = self._average_over_repetitions(distCol, j, stdDev)
                    label = asympPrefix + ("std(mean)=%.2e" % sigma)
                else:
                    mu = mean(distCol[-1][j])
                    sigma = stdDev(distCol[-1][j])
                    label = asympPrefix + ("std=%.2e" % sigma)
                View(
                    ot.Normal(mu, sigma).drawPDF(),
                    axes=[axis],
                    plot_kwargs={"label": label},
                )
                if exact is not None:
                    axis.vlines(
                        exact[j],
                        0,
                        axis.get_ylim()[1],
                        label="exact - %0.3f" % exact[j],
                    )
                axis.set_xlabel(xlabelFormat % (j))
                axis.set_ylabel("Density")
                axis.legend(loc="lower right")

        if mean_distribution:
            directory = "graphe_validation_mean_distribution/"
        else:
            directory = "graphe_validation/"
        if self.savefig:
            # Only create the output directory when something is written
            pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
            fig.savefig(
                directory
                + "%s-%s.png" % (self.model.getName(), self.sobol_estimator.__name__),
                transparent=True,
                bbox_inches="tight",
            )
        # Release the figure: the script creates many of them
        pl.close(fig)
        return fig


if __name__ == "__main__":
    # Estimators validated on every test case, in this order
    estimators = [
        SaltelliSensitivityAlgorithm,
        JansenSensitivityAlgorithm,
        MauntzKucherenkoSensitivityAlgorithm,
        MartinezSensitivityAlgorithm,
    ]

    # Test cases: (model, input distribution, exact first order indices,
    # exact total order indices); exact indices are None when unknown
    test_cases = []

    # GSOBOL
    a_gsobol = array([0, 9, 99])
    # Uniform input distribution of the G-Sobol test case
    distribution_gsobol = gsobolDistribution(len(a_gsobol))
    # Bind the coefficients as a default argument so that the model does not
    # depend on a global name that could be rebound later
    model_gsobol = ot.PythonFunction(
        len(a_gsobol), 1, func_sample=lambda X, a=a_gsobol: gsobol(X, a)
    )
    model_gsobol.setName("G-Sobol")
    # Exact sensitivity indices
    _, _, sexact_gsobol, stexact_gsobol = gsobolSAExact(a_gsobol)
    test_cases.append(
        (model_gsobol, distribution_gsobol, sexact_gsobol, stexact_gsobol)
    )

    # ISHIGAMI
    model_ishigami = ishigamiGSymbolic()
    model_ishigami.setName("Ishigami")
    distribution_ishigami = ishigamiDistribution()
    # Exact sensitivity indices
    a_ishigami, b_ishigami = ishigamiAB()
    _, _, S1, S2, S3, ST1, ST2, ST3 = ishigamiSAExact(a_ishigami, b_ishigami)
    test_cases.append(
        (
            model_ishigami,
            distribution_ishigami,
            ot.Point([S1, S2, S3]),
            ot.Point([ST1, ST2, ST3]),
        )
    )

    # POUTRE (beam)
    model_poutre = ot.SymbolicFunction(
        ["L", "b", "h", "E", "F"], ["F * L^3 / (48 * E * b * h^3 / 12)"]
    )
    model_poutre.setName("poutre")
    # LogNormal marginals of L, b, h, E, F given by (mu, sigma / mu, gamma)
    marginals_poutre = []
    for mu, sigma_over_mu in [
        (5.0, 0.02),
        (0.2, 0.05),
        (0.4, 0.05),
        (3e4, 0.12),
        (0.1, 0.20),
    ]:
        marginal = ot.LogNormal()
        marginal.setParameter(ot.LogNormalMuSigmaOverMu()([mu, sigma_over_mu, 0.0]))
        marginals_poutre.append(marginal)
    distribution_poutre = ot.JointDistribution(marginals_poutre)
    test_cases.append((model_poutre, distribution_poutre, None, None))

    # Aggregated Sobol
    model_aggregated = ot.SymbolicFunction(
        ["X1", "X2", "X3"],
        ["2*X1 + X2 - 3*X3 + 0.3*X1*X2", "-5*X1 + 4*X2 - 0.8*X2*X3 + 2*X3"],
    )
    model_aggregated.setName("AggregatedSobol")
    distribution_aggregated = ot.JointDistribution([ot.Uniform()] * 3)
    test_cases.append((model_aggregated, distribution_aggregated, None, None))

    # Run every estimator on every test case; keep each test object under a
    # distinct key (model name, estimator name) so none is overwritten
    sensitivity_tests = {}
    for model, distribution, fo_exact, to_exact in test_cases:
        for estimator in estimators:
            sensitivity_tests[(model.getName(), estimator.__name__)] = (
                SensitivityConfidenceTest(
                    model,
                    distribution,
                    estimator,
                    FOexact=fo_exact,
                    TOexact=to_exact,
                    savefig=True,
                    plot_figure=True,
                )
            )
